"""
    Train various models (vanilla and Hessian-aware variants) with PyTorch.
"""
import os, csv, json
import random
import argparse
import numpy as np
from tqdm import tqdm

# torch modules
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.nn.parallel import DataParallel

# custom libs
from utils.datasets import load_dataset
from utils.networks import load_network, load_trained_network
from utils.optims import define_loss_function, define_optimizer

#utils
from utils.learners import train, valid, train_with_hessian, train_with_hessian_trace, train_with_adaHessian, train_with_hessian_layer_track, train_hessTrace


# ------------------------------------------------------------------------------
#   Globals
# ------------------------------------------------------------------------------
# Best validation accuracy observed so far; the checkpoint logic compares
# each epoch's validation accuracy against it and raises it on improvement.
best_acc = 0
# Large "lowest so far" sentinels; the training code in this script does not
# currently update them.
best_hess_loss = best_trace = 10000

# ------------------------------------------------------------------------------
#   Hessian-vector product helper
# ------------------------------------------------------------------------------


def compute_hessian_vector_product(model, loss, data, labels, vector):
    """Return the Hessian-vector product ``H @ vector`` of the loss w.r.t. the model parameters.

    Uses double backpropagation: the loss gradient is built with
    ``create_graph=True`` and then differentiated a second time with
    ``vector`` as the output weighting.

    Args:
        model: ``nn.Module`` whose parameters define the Hessian.
        loss: callable ``loss(output, labels)`` returning a scalar tensor.
        data: input passed to ``model`` for the forward pass.
        labels: targets passed to ``loss``.
        vector: tensor with one element per model parameter scalar, laid out
            in ``model.parameters()`` order (it is flattened before use).

    Returns:
        Tuple of tensors, one per parameter and shaped like it, holding the
        corresponding slice of ``H @ vector``.

    The ``requires_grad`` flag of every parameter is restored to its value on
    entry, even if autograd raises, so calling this helper never freezes (or
    unfreezes) parameters for later training.
    """
    params = list(model.parameters())
    saved_flags = [p.requires_grad for p in params]
    try:
        # Every parameter must be differentiable for the double backward pass.
        for p in params:
            p.requires_grad_(True)

        # Forward pass and first-order gradients, keeping the graph so the
        # gradients themselves can be differentiated again.
        output = model(data)
        loss_value = loss(output, labels)
        grads = torch.autograd.grad(loss_value, params, create_graph=True)

        # d(grad . v)/d(theta) == H @ v; reshape (not view) tolerates
        # non-contiguous gradient tensors.
        flat_grads = torch.cat([g.reshape(-1) for g in grads])
        hvp = torch.autograd.grad(flat_grads, params, grad_outputs=vector.reshape(-1))
    finally:
        for p, flag in zip(params, saved_flags):
            p.requires_grad_(flag)

    return hvp



def count_parameters(model):
    """Return how many scalar elements all of ``model``'s parameters hold in total.

    Frozen parameters (``requires_grad=False``) are included in the count.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total

# ------------------------------------------------------------------------------
#   Run training
# ------------------------------------------------------------------------------
def run_training(args):
    """Train ``args.network`` on ``args.dataset`` with the learner selected by ``args``.

    Exactly one learner runs, chosen in this priority order:
    ``--hessian`` -> ``train_with_hessian``,
    ``--hessian_layer_track`` -> ``train_with_hessian_layer_track``,
    ``--hessianTR`` -> ``train_with_hessian_trace``,
    ``--adaHessian`` -> ``train_with_adaHessian``,
    otherwise the plain ``train``.

    After every epoch the network is validated, one metrics row is appended to
    ``results/<dataset>/<folder>/<prefix>.csv`` and the weights are written to
    ``models/<dataset>/<folder>/`` whenever validation accuracy beats the
    module-level ``best_acc`` (``folder``/``prefix`` come from
    ``define_vanilla_store_prefixes``).

    Args:
        args: parsed command-line namespace (see the ``__main__`` parser).
            ``args.cuda`` is set to False in place when CUDA is unavailable.
    """
    # fall back to CPU when CUDA is unavailable
    if not torch.cuda.is_available():
        args.cuda = False

    # init. the random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        # set the CUDNN backend as deterministic
        cudnn.deterministic = True

    # init. dataset (train/test); loader worker/pinning options only on GPU
    kwargs = {
        'num_workers': args.num_workers,
        'pin_memory' : args.pin_memory,
    } if args.cuda else {}
    train_loader, valid_loader = load_dataset(
        args.dataset, args.datapth, args.batch_size, args.augment, kwargs)
    print (' : Load the dataset [{}] from [{}]'.format(args.dataset, args.datapth))

    # init. the network (optionally from a pre-trained checkpoint)
    network = load_network(args.dataset, args.network)
    if args.trained:
        load_trained_network(network, args.cuda, args.trained)
    netname = type(network).__name__
    if args.cuda:
        network.cuda(device=0)
    print (' : Define a network [{}]'.format(netname))

    # report the parameter layout
    for name, _ in network.named_parameters():
        print(name)
    print(f'Total number of parameters: {count_parameters(network)}')
    for name, param in network.named_parameters():
        print(f"Layer: {name}, Parameters: {param.numel()}")

    # init. loss function
    task_loss = define_loss_function(args.lossfunc)
    print (' : Define a loss function [{}]'.format(args.lossfunc))

    # init. the optimizer
    optimizer, scheduler = define_optimizer(
        network, args.optimizer, args.lr, args.schedule)
    print (' : Define an optimizer [{}] w. [{}]'.format(args.optimizer, args.lr))

    # init. output dirs [vanilla training]
    store_paths = {}
    store_paths['folder'], store_paths['prefix'] = define_vanilla_store_prefixes(args)
    store_paths['model']  = os.path.join('models', args.dataset, store_paths['folder'])
    store_paths['result'] = os.path.join('results', args.dataset, store_paths['folder'])
    os.makedirs(store_paths['model'], exist_ok=True)
    os.makedirs(store_paths['result'], exist_ok=True)
    print (' : Store the results under [{}], with the prefix [{}]'.format(
        store_paths['model'], store_paths['prefix']))

    # NOTE(review): the plain-training branch has always saved to
    # '<prefix>_3.pth' while the Hessian branches save to '<prefix>.pth';
    # kept so existing checkpoints are still found -- confirm the '_3' tag.
    any_hessian_mode = (args.hessian or args.hessian_layer_track
                        or args.hessianTR or args.adaHessian)
    ckpt_suffix = '' if any_hessian_mode else '_3'

    for epoch in range(1, args.epoch + 1):
        # each branch trains + validates and collects its extra CSV columns
        if args.hessian:
            train_acc, train_loss, total_loss, hessian_loss = train_with_hessian(
                args, epoch, network, train_loader, task_loss, optimizer, scheduler)
            valid_acc, valid_loss = valid(
                args, epoch, network, valid_loader, task_loss, store_paths)
            mode_metrics = [total_loss, hessian_loss]

        elif args.hessian_layer_track:
            train_acc, train_loss, total_loss, hessian_loss, avg_layer_hessian_values = \
                train_with_hessian_layer_track(
                    args, epoch, network, train_loader, task_loss, optimizer, scheduler)
            valid_acc, valid_loss = valid(
                args, epoch, network, valid_loader, task_loss, store_paths)

            print("Keys in avg_layer_hessian_values:", avg_layer_hessian_values.keys())
            # per-layer value = mean of the layer's weight and bias entries;
            # assumes the network exposes exactly these parameter names
            tracked_layers = ('conv1', 'conv2', 'fc1', 'fc2')
            layer_hessians = [
                _layer_hessian(avg_layer_hessian_values, layer) for layer in tracked_layers]

            print("Epoch:", epoch)
            for layer, value in zip(tracked_layers, layer_hessians):
                print("Average Hessian value for {} layer:".format(layer), value)
            mode_metrics = [total_loss, hessian_loss] + layer_hessians

        elif args.hessianTR:
            print(f"In epoch : {epoch}")
            train_acc, train_loss, total_loss, hessian_loss = train_with_hessian_trace(
                args, epoch, network, train_loader, task_loss, optimizer, scheduler)
            valid_acc, valid_loss = valid(
                args, epoch, network, valid_loader, task_loss, store_paths)
            mode_metrics = [total_loss, hessian_loss]

        elif args.adaHessian:
            train_acc, train_loss = train_with_adaHessian(
                args, epoch, network, train_loader, task_loss, optimizer, scheduler)
            valid_acc, valid_loss = valid(
                args, epoch, network, valid_loader, task_loss, store_paths)
            mode_metrics = []

        else:
            train_acc, train_loss = train(
                args, epoch, network, train_loader, task_loss, optimizer, scheduler)
            current_lr = optimizer.param_groups[0]['lr']
            print(f'Epoch {epoch}, Learning Rate: {current_lr}')
            valid_acc, valid_loss = valid(
                args, epoch, network, valid_loader, task_loss, store_paths)
            mode_metrics = []

        # best_acc is the module-level value *before* this epoch's update
        record_row = [epoch, train_acc, train_loss, valid_acc, valid_loss, best_acc] + mode_metrics
        _save_epoch_record(store_paths, epoch, record_row)
        _save_if_best(network, store_paths, valid_acc, suffix=ckpt_suffix)

    print (': Done, training')
    # done.


def _save_epoch_record(store_paths, epoch, record_row):
    """Append one per-epoch metrics row to ``<result>/<prefix>.csv``.

    On the first epoch a CSV left over from an earlier run with the same
    prefix is deleted first, so every run starts a fresh file.
    """
    record_savefile = os.path.join(
        store_paths['result'], '{}.csv'.format(store_paths['prefix']))
    if epoch < 2 and os.path.exists(record_savefile):
        os.remove(record_savefile)
    csv_logger(record_row, record_savefile)


def _save_if_best(network, store_paths, valid_acc, suffix=''):
    """Save ``network``'s weights when ``valid_acc`` beats the module-level ``best_acc``.

    The checkpoint goes to ``<model>/<prefix><suffix>.pth``; ``best_acc`` is
    raised to ``valid_acc`` on improvement.
    """
    global best_acc
    model_savefile = os.path.join(
        store_paths['model'], '{}{}.pth'.format(store_paths['prefix'], suffix))
    if valid_acc > best_acc:
        torch.save(network.state_dict(), model_savefile)
        print (' -> cur acc. [{:.4f}] > best acc. [{:.4f}], store the model.\n'.format(valid_acc, best_acc))
        best_acc = valid_acc


def _layer_hessian(avg_layer_hessian_values, layer):
    """Average the tracked values stored under ``<layer>.weight`` and ``<layer>.bias``."""
    weight = avg_layer_hessian_values['{}.weight'.format(layer)]
    bias = avg_layer_hessian_values['{}.bias'.format(layer)]
    return (weight + bias) / 2




# ------------------------------------------------------------------------------
#   Misc functions
# ------------------------------------------------------------------------------
def csv_logger(data, filepath):
    """Append ``data`` as a single CSV row to ``filepath``, creating the file if needed.

    The file is opened with ``newline=''`` as the ``csv`` module requires.
    Without it, platforms that translate newlines (e.g. Windows) write an
    extra carriage return after each row, which shows up as blank lines.
    """
    with open(filepath, 'a', newline='') as csv_output:
        csv.writer(csv_output).writerow(data)
    # done.

def define_vanilla_store_prefixes(args):
    """Return ``(store_folder, store_prefix)`` naming this run's output files.

    The prefix is ``<network>_<seed>_<batch_size>_<epoch>_<optimizer>_<lr>``
    followed by a training-mode tag. The Hess and HessTR tags also embed
    ``args.hessLR``. The mode is resolved in the same priority order that
    ``run_training`` uses to pick a learner, so the files are always named after
    the learner that actually ran:
    ``hessian`` > ``hessian_layer_track`` > ``hessianTR`` > ``adaHessian`` > plain.
    """
    store_folder = 'vanilla'
    base = '{}_{}_{}_{}_{}_{}'.format(
        args.network, args.seed, args.batch_size, args.epoch, args.optimizer, args.lr)

    if args.hessian:
        store_prefix = '{}_{}_Hess'.format(base, args.hessLR)
    elif args.hessian_layer_track:
        store_prefix = '{}_Hessian_Layer'.format(base)
    elif args.hessianTR:
        store_prefix = '{}_{}_HessTR_4'.format(base, args.hessLR)
    elif args.adaHessian:
        store_prefix = '{}_adaHessian'.format(base)
    else:
        store_prefix = base
    return store_folder, store_prefix

# ------------------------------------------------------------------------------
#   Main (to train an ML model in PyTorch)
# ------------------------------------------------------------------------------
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Train an ML model with PyTorch (vanilla)')

    # system parameters
    parser.add_argument('--seed', type=int, default=215,
                        help='random seed (default: 215)')
    parser.add_argument('--cuda', action='store_true',
                        help='enables CUDA training')
    parser.add_argument('--num-workers', type=int, default=4,
                        help='number of data-loader workers, used only with CUDA (default: 4)')
    # store_false: pinning is on by default and this flag turns it off
    parser.add_argument('--pin-memory', action='store_false',
                        help='disable CUDA pinned memory in the data loader (pinned by default)')

    # dataset parameters
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='dataset used to train (default: cifar10).')
    parser.add_argument('--datapth', type=str, default='',
                        help='dataset location (which uses a processed file)')
    parser.add_argument('--augment', type=str, default='',
                        help='data augmentation method')

    # model parameters
    parser.add_argument('--network', type=str, default='ConvNet',
                        help='model name (default: ConvNet).')
    parser.add_argument('--trained', type=str, default='',
                        help='pre-trained model filepath.')
    parser.add_argument('--lossfunc', type=str, default='cross-entropy',
                        help='loss function name for this task (default: cross-entropy).')
    parser.add_argument('--classes', type=int, default=10,
                        help='number of classes in the dataset (ex. 10 in CIFAR10).')

    # hyper-parameters
    parser.add_argument('--batch-size', type=int, default=125,
                        help='input batch size for training (default: 125)')
    parser.add_argument('--epoch', type=int, default=50,
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--optimizer', type=str, default='SGD',
                        help='optimizer used to train (default: SGD)')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate (default: 0.1)')
    parser.add_argument('--lamda', type=float, default=0.1,
                        help='lambda coefficient (default: 0.1)')
    parser.add_argument('--hessLR', type=float, default=0.1,
                        help='Hessian learning rate (default: 0.1)')
    parser.add_argument('--schedule', default=[100, 150, 200], type=int, nargs='+',
                        help='adjust learning rate at each specific epoch')

    # hessian parameters (checked in this order; the first one set wins)
    parser.add_argument('--hessian', action='store_true',
                        help='enables hessian-aware training (train_with_hessian)')
    parser.add_argument('--hessian_layer_track', action='store_true',
                        help='enables hessian-aware training with per-layer Hessian tracking')
    parser.add_argument('--hessianTR', action='store_true',
                        help='enables hessian-trace-aware training (train_with_hessian_trace)')
    parser.add_argument('--adaHessian', action='store_true',
                        help='enables AdaHessian training (train_with_adaHessian)')

    # execution parameters
    args = parser.parse_args()
    print (json.dumps(vars(args), indent=2))

    # run the training
    run_training(args)

    # done.
